{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sad_construct import *\n",
    "from sad_read_mnist import get_mnist\n",
    "\n",
    "# Per-sample image shape (presumably channels, height, width); it is also\n",
    "# passed to train_and_test as data_shape in the training cells.\n",
    "mnist_shape = (1, 28, 28)\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = get_mnist()\n",
    "# Flatten every image into one row of prod(mnist_shape) = 784 features, so the\n",
    "# flat width is derived from mnist_shape instead of a second hard-coded literal.\n",
    "x_train = x_train.reshape(-1, int(np.prod(mnist_shape)))\n",
    "x_test = x_test.reshape(-1, int(np.prod(mnist_shape)))\n",
    "# Put the label column(s) after the pixel columns so each row is [pixels..., label].\n",
    "# Assumes y_train / y_test are already 2-D, e.g. (N, 1) -- TODO confirm in get_mnist.\n",
    "mnist_train_data = np.concatenate((x_train, y_train), axis=1)\n",
    "mnist_test_data = np.concatenate((x_test, y_test), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sad_network import NeuralNetwork as Network\n",
    "from sad_layer import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init para\n",
      "epoch 1:\n",
      " Loss:0.439810349 \tAccuracy: 0.9311 [0.97857, 0.98062, 0.91667, 0.93564, 0.92974, 0.8565, 0.93737, 0.93093, 0.91068, 0.92071]\n",
      "epoch 2:\n",
      " Loss:0.220160475 \tAccuracy: 0.9471 [0.9898, 0.98414, 0.93217, 0.94554, 0.93075, 0.94283, 0.94676, 0.94553, 0.92608, 0.9227]\n",
      "epoch 3:\n",
      " Loss:0.165519108 \tAccuracy: 0.9577 [0.98469, 0.97885, 0.95349, 0.96436, 0.96538, 0.93049, 0.95616, 0.95914, 0.9384, 0.94054]\n",
      "epoch 4:\n",
      " Loss:0.134408998 \tAccuracy: 0.9634 [0.98469, 0.98855, 0.96512, 0.96931, 0.96741, 0.94955, 0.97077, 0.97374, 0.93224, 0.92765]\n",
      "epoch 5:\n",
      " Loss:0.111120193 \tAccuracy: 0.9637 [0.97959, 0.98767, 0.96705, 0.98218, 0.96538, 0.9361, 0.96242, 0.94455, 0.95277, 0.95342]\n",
      "epoch 6:\n",
      " Loss:0.094798973 \tAccuracy: 0.9683 [0.98367, 0.99119, 0.96512, 0.98119, 0.98167, 0.94843, 0.96555, 0.9679, 0.96099, 0.93261]\n",
      "epoch 7:\n",
      " Loss:0.082416423 \tAccuracy: 0.9696 [0.98367, 0.98678, 0.97093, 0.95941, 0.98371, 0.97197, 0.97912, 0.96304, 0.9692, 0.92765]\n",
      "epoch 8:\n",
      " Loss:0.071902609 \tAccuracy: 0.9723 [0.98673, 0.98767, 0.97384, 0.97723, 0.97963, 0.96076, 0.96451, 0.96304, 0.97433, 0.95243]\n",
      "epoch 9:\n",
      " Loss:0.063316514 \tAccuracy: 0.9743 [0.98776, 0.99295, 0.98062, 0.97723, 0.97149, 0.96525, 0.97704, 0.9679, 0.96407, 0.9554]\n",
      "epoch 10:\n",
      " Loss:0.056646597 \tAccuracy: 0.9733 [0.98571, 0.98678, 0.96415, 0.98614, 0.97862, 0.96637, 0.9666, 0.97082, 0.97536, 0.95045]\n"
     ]
    }
   ],
   "source": [
    "# Experiment 1: train fc_mnist with a softmax + cross-entropy loss.\n",
    "# NOTE(review): fc_mnist and train_and_test are not defined in this notebook;\n",
    "# presumably they come from `from sad_construct import *` -- confirm.\n",
    "fc_mnist.lossfunction = SoftmaxCrossEntropyLossLayer()\n",
    "train_and_test(\n",
    "    mynetwork=fc_mnist,\n",
    "    train_data=mnist_train_data, test_data=mnist_test_data,\n",
    "    data_shape=mnist_shape, onehotsize=10,  # presumably 10 digit classes\n",
    "    batch_size=16, epoch_num=10,\n",
    "    pth=\"fc_mnist_b16_l1.npy\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init para\n",
      "epoch 1:\n",
      " Loss:0.201708584 \tAccuracy: 0.8112 [0.62143, 0.98502, 0.91957, 0.0, 0.95418, 0.93386, 0.95929, 0.92412, 0.923, 0.88999]\n",
      "epoch 2:\n",
      " Loss:0.158255147 \tAccuracy: 0.83 [0.70102, 0.9859, 0.94477, 0.0, 0.96741, 0.95516, 0.96033, 0.93385, 0.94353, 0.90981]\n",
      "epoch 3:\n",
      " Loss:0.147148709 \tAccuracy: 0.8432 [0.76735, 0.98678, 0.95446, 0.0, 0.95825, 0.95628, 0.96973, 0.94747, 0.95483, 0.93954]\n",
      "epoch 4:\n",
      " Loss:0.115472867 \tAccuracy: 0.8657 [0.98571, 0.98767, 0.95252, 0.0, 0.97251, 0.9574, 0.95825, 0.94455, 0.97228, 0.9336]\n",
      "epoch 5:\n",
      " Loss:0.091598239 \tAccuracy: 0.8697 [0.9898, 0.98767, 0.95833, 0.0, 0.97454, 0.97309, 0.96347, 0.94942, 0.97125, 0.93855]\n",
      "epoch 6:\n",
      " Loss:0.087211597 \tAccuracy: 0.8709 [0.9898, 0.98678, 0.96996, 0.0, 0.97352, 0.96076, 0.97286, 0.9572, 0.96099, 0.9445]\n",
      "epoch 7:\n",
      " Loss:0.084100716 \tAccuracy: 0.8727 [0.98776, 0.98855, 0.96802, 0.0, 0.97454, 0.96525, 0.96973, 0.96401, 0.97947, 0.93756]\n",
      "epoch 8:\n",
      " Loss:0.081708747 \tAccuracy: 0.875 [0.9898, 0.98767, 0.96996, 0.0, 0.97149, 0.97534, 0.97286, 0.96109, 0.97844, 0.95243]\n",
      "epoch 9:\n",
      " Loss:0.079718983 \tAccuracy: 0.8746 [0.98673, 0.98855, 0.96415, 0.0, 0.97352, 0.96973, 0.97495, 0.96498, 0.97741, 0.95441]\n",
      "epoch 10:\n",
      " Loss:0.078079146 \tAccuracy: 0.8755 [0.98776, 0.99031, 0.96512, 0.0, 0.97556, 0.97085, 0.97704, 0.96693, 0.97639, 0.95342]\n"
     ]
    }
   ],
   "source": [
    "# Experiment 2: append a ReLU after fc_mnist's current last layer, then train with MSE.\n",
    "# The append mutates the shared fc_mnist object, so it is guarded: re-running this\n",
    "# cell must not stack another ReLU on top of the one already appended.\n",
    "# NOTE(review): in the recorded run class 3 scored 0.0 accuracy in every epoch; a ReLU\n",
    "# as the final layer can leave an output unit stuck at 0 -- worth investigating.\n",
    "if not (fc_mnist.layers and isinstance(fc_mnist.layers[-1], ReLuLayer)):\n",
    "    fc_mnist.layers.append(ReLuLayer())\n",
    "fc_mnist.lossfunction=MSELossLayer()\n",
    "train_and_test(\n",
    "    mynetwork=fc_mnist,\n",
    "    train_data=mnist_train_data,\n",
    "    test_data=mnist_test_data,\n",
    "    data_shape=mnist_shape,\n",
    "    onehotsize=10,\n",
    "    batch_size=16,\n",
    "    epoch_num=10,\n",
    "    pth=\"fc_mnist_b16_l2.npy\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init para\n",
      "epoch 1:\n",
      " Loss:0.160295935 \tAccuracy: 0.851 [0.98265, 0.98855, 0.94186, 0.92475, 0.94501, 0.91256, 0.95825, 0.93969, 0.0, 0.88503]\n",
      "epoch 2:\n",
      " Loss:0.109355235 \tAccuracy: 0.8609 [0.98571, 0.98943, 0.93508, 0.95644, 0.94705, 0.91928, 0.97077, 0.94747, 0.0, 0.92666]\n",
      "epoch 3:\n",
      " Loss:0.098400089 \tAccuracy: 0.8681 [0.98061, 0.99031, 0.95349, 0.95941, 0.96436, 0.93834, 0.97808, 0.95039, 0.0, 0.93657]\n",
      "epoch 4:\n",
      " Loss:0.092055628 \tAccuracy: 0.8697 [0.98571, 0.99031, 0.96027, 0.95248, 0.97149, 0.94843, 0.97286, 0.95039, 0.0, 0.93657]\n",
      "epoch 5:\n",
      " loss:0.049964499 \tprocess: 42.29%\t   "
     ]
    }
   ],
   "source": [
    "# Experiment 3: train fc_mnist, ending in a ReLU, with a Huber loss.\n",
    "# The recorded run of this cell came right after the MSE cell (execution counts\n",
    "# 2 then 3), i.e. with that cell's ReLU appended. Make that architecture explicit\n",
    "# here so Restart & Run All does not depend on hidden kernel state; this is a\n",
    "# no-op when the last layer is already a ReLU.\n",
    "# NOTE(review): the saved output stops mid-epoch 5 (run interrupted), and class 8\n",
    "# scored 0.0 accuracy in every completed epoch -- rerun to completion and investigate.\n",
    "if not (fc_mnist.layers and isinstance(fc_mnist.layers[-1], ReLuLayer)):\n",
    "    fc_mnist.layers.append(ReLuLayer())\n",
    "fc_mnist.lossfunction=HuberLossLayer()\n",
    "train_and_test(\n",
    "    mynetwork=fc_mnist,\n",
    "    train_data=mnist_train_data,\n",
    "    test_data=mnist_test_data,\n",
    "    data_shape=mnist_shape,\n",
    "    onehotsize=10,\n",
    "    batch_size=16,\n",
    "    epoch_num=10,\n",
    "    pth=\"fc_mnist_b16_l3.npy\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "face",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
